from ast import parse
import shutil
from src.bulletEnv import SACBulletEnv
from utils import env_utils
from stable_baselines3 import SAC
import os 
import pickle
from src.simulator import Simulator
import yaml
from src.data_structures import SingleValueSampler
import time
from src.ll_control import PolicyTrainer as Train
import numpy as np
from stable_baselines3.common.noise import NormalActionNoise, VectorizedActionNoise
from src.custom_callbacks import CustomEval, CustomEarlyStop
from stable_baselines3.common.callbacks import EvalCallback, CallbackList 
from stable_baselines3.common.vec_env import SubprocVecEnv
from src.logger import SummaryWriterCallback
import sys

# Run identifier taken from the first command-line argument; importing or
# running this module without one raises IndexError here.
runid = sys.argv[1]
# Seed passed unchanged to every environment created by ``train``.
seed = 1337

def parse_yaml(yaml_file):
    """Load and parse a YAML file.

    Args:
        yaml_file: Path of the YAML document to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    with open(yaml_file, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def make_env(env):
    """Wrap an already-constructed environment in a zero-argument factory.

    Args:
        env: The environment object to hand out.

    Returns:
        A callable that returns ``env`` itself (the very same object) on
        every call.
    """
    return lambda: env

def make_subproc_envs(num,init_config,goal_config,robot_config,env_path,max_ep_len,gui,seed, problem_number):
    """Create ``num`` ``SACBulletEnv`` instances, each in its own worker process.

    ``SubprocVecEnv`` is given one picklable factory per worker instead of an
    already-built environment, so each ``SACBulletEnv`` is constructed inside
    the worker process that steps it. Whatever ``SACBulletEnv.__init__``
    allocates therefore lives in that worker, not in the parent process.

    Args:
        num: Number of environments / worker processes.
        init_config: Start configuration forwarded to every environment.
        goal_config: Goal configuration forwarded to every environment.
        robot_config: Robot configuration forwarded to every environment.
        env_path: Path of the environment mesh forwarded to every environment.
        max_ep_len: Episode-length limit forwarded to every environment.
        gui: GUI flag forwarded to every environment.
        seed: Seed forwarded (unchanged) to every environment.
        problem_number: Problem index forwarded to every environment.

    Returns:
        A ``SubprocVecEnv`` using the ``'spawn'`` start method. Environment
        ``i`` (0-based) receives ``envid=i + 1``.
    """
    from functools import partial

    env_fns = [
        partial(
            SACBulletEnv,
            init_config=init_config,
            goal_config=goal_config,
            robot_config=robot_config,
            env_path=env_path,
            max_ep_len=max_ep_len,
            gui=gui,
            seed=seed,
            envid=rank + 1,
            problem_number=problem_number,
        )
        for rank in range(num)
    ]
    return SubprocVecEnv(env_fns, start_method='spawn')




def train(config_params, init_config, goal, tblog_prefix, problem_number):
    """Train a SAC policy for one start/goal problem and save it to disk.

    The function first spawns ``config_params['region_policy']['train_envs']``
    training environments and one evaluation environment with the GUI
    enabled. It then trains SAC with an ``EvalCallback`` (early stopping via
    ``CustomEarlyStop``) and writes the resulting policy into
    ``config_params['policy_folder']``. All environment subprocesses are
    closed before returning, even if training raises.

    Args:
        config_params: Parsed experiment configuration (see ``parse_yaml``).
        init_config: Start configuration forwarded to each environment.
        goal: Goal configuration forwarded to each environment.
        tblog_prefix: Sub-directory name used for the TensorBoard log root and
            for the best-model checkpoint directory.
        problem_number: Problem index forwarded to each environment.

    Returns:
        Tuple ``(path, training_steps)``:

        - ``path`` is the file the policy was written to. It holds the
          evaluation callback's ``best_model.zip`` if one exists, otherwise
          the final model.
        - ``training_steps`` is ``model.num_timesteps`` after learning.
    """
    policy_cfg = config_params["region_policy"]

    # Constructor arguments shared by the training and evaluation envs.
    env_kwargs = dict(
        seed=seed,
        init_config=init_config,
        goal_config=goal,
        robot_config=config_params["robot"],
        env_path=os.path.join(config_params["env_path"], config_params["env_name"] + ".stl"),
        max_ep_len=policy_cfg["max_ep_len"],
        problem_number=problem_number,
    )

    envs = make_subproc_envs(num=policy_cfg["train_envs"],
                             gui=config_params["trainer_gui"],
                             **env_kwargs)
    eval_env = None
    try:
        # Pause after each vec-env creation (kept from the original script).
        time.sleep(5)
        eval_env = make_subproc_envs(num=1, gui=True, **env_kwargs)
        time.sleep(5)

        n_actions = envs.action_space.shape[-1]
        action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

        # %b is the portable spelling of the non-standard %h (same output).
        tensorboard_log_dir = "./baselines_logs/sac/{}/{}/".format(tblog_prefix, time.strftime("%d%y%b_%H%M"))

        model = SAC("MlpPolicy",
                    env=envs,
                    train_freq=(policy_cfg["train_freq"], "step"),
                    gradient_steps=policy_cfg["update_batch"],
                    batch_size=policy_cfg["batch_size"],
                    buffer_size=policy_cfg["buffer_size"],
                    policy_kwargs=dict(net_arch=[32, 32]),
                    learning_rate=0.003,
                    action_noise=action_noise,
                    tensorboard_log=tensorboard_log_dir)

        # make_log_dir is invoked on every training env; the first env's
        # return value names the TensorBoard run.
        log_name = envs.env_method("make_log_dir")[0]

        callback_on_best = CustomEarlyStop(
            reward_threshold=200,
            max_no_improvement_evals=5,
            min_evals=2,
            verbose=0)

        best_model_path = os.path.join('./sac_eval/best/', tblog_prefix)
        eval_callback = EvalCallback(eval_env=eval_env,
                                     n_eval_episodes=100,
                                     callback_on_new_best=callback_on_best,
                                     eval_freq=policy_cfg['eval_freq'],
                                     best_model_save_path=best_model_path,
                                     verbose=1)
        callbacks = CallbackList([eval_callback, SummaryWriterCallback()])

        model.learn(total_timesteps=policy_cfg["train_timesteps"],
                    log_interval=10,
                    tb_log_name="{}".format(log_name),
                    reset_num_timesteps=False,
                    callback=callbacks)
        training_steps = model.num_timesteps
    finally:
        # Shut down the environment worker processes even if training fails.
        envs.close()
        if eval_env is not None:
            eval_env.close()

    policy_folder = config_params["policy_folder"]
    # shutil.copyfile does not create missing parent directories.
    os.makedirs(policy_folder, exist_ok=True)
    path = os.path.join(policy_folder, time.strftime('%d%y%b_%H%M%S') + '.pt')

    # NOTE(review): a best_model.zip left over from an earlier run with the
    # same tblog_prefix would also be picked up here.
    best_model = os.path.join(best_model_path, "best_model.zip")
    if os.path.isfile(best_model):
        shutil.copyfile(best_model, path)
    else:
        model.save(path)

    return path, training_steps

def generate_and_test(problem_config):
    """Train SAC baselines on the stored start/goal problems for one environment.

    Problems are loaded from
    ``<experiments_path>/<robot name>_<env_name>.pickle``, which must provide
    indexable ``'init'`` and ``'goal'`` entries. Only the first three values
    of each start/goal are used.

    For environments whose name does not contain ``"small"``, every one of
    those values except the last is multiplied by 3 (in place).

    After each problem is trained, one CSV row
    ``env_name,problem_number,policy_path,training_steps`` is appended to
    ``./baselines/results.txt``.

    NOTE(review): the loop ``break``s after its first iteration, so only
    problem #1 is ever trained, whatever ``num_test_problems`` says. This
    looks like a debugging leftover but is preserved to keep current
    behavior.

    Args:
        problem_config: Parsed experiment configuration (see ``parse_yaml``).

    Returns:
        None.
    """
    env_name = problem_config['env_name']
    experiment_fname = '{}_{}.pickle'.format(problem_config['robot']['name'], env_name)
    experiment = os.path.join(problem_config['experiments_path'], experiment_fname)
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(experiment, 'rb') as f:
        load_locations = pickle.load(f)

    scale_coords = "small" not in env_name
    results_file = "./baselines/results.txt"

    for i in range(problem_config["num_test_problems"]):  # TODO remove the slice and make th
        problem_number = i + 1
        init = load_locations['init'][i][:3]
        goal = load_locations['goal'][i][:3]

        if scale_coords:
            for k in range(len(init) - 1):
                init[k] = init[k] * 3
                goal[k] = goal[k] * 3
        print("Running problem #{0}".format(problem_number))

        prefix = env_name + "_" + str(problem_number)
        path, steps = train(problem_config, init, goal, prefix, problem_number)

        # Appending fails if the results directory is missing.
        os.makedirs(os.path.dirname(results_file), exist_ok=True)
        with open(results_file, "a") as f:
            f.write("{},{},{},{}\n".format(env_name, problem_number, path, steps))
        # Unconditional: only the first problem runs (see docstring note).
        break

if __name__ == "__main__":
    # Script entry point: load the experiment configuration and run the baseline.
    # NOTE: ``problem_config`` becomes a module-level global here and code
    # elsewhere in this file may read it as such, so keep it at module scope.
    problem_config = parse_yaml("baselines_parameters.yaml")
    generate_and_test(problem_config)